{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "984d4ffc",
   "metadata": {},
   "source": [
    "# Lesson 2: Image Segmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3d21427",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06744c6f",
   "metadata": {},
   "source": [
    "* In this classroom, the libraries have already been installed for you.\n",
    "* If you would like to run this code on your own machine, you need to install the following:\n",
    "    ```\n",
    "    !pip install ultralytics torch\n",
    "    ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "150a0374",
   "metadata": {},
   "source": [
    "### Load the sample image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49c6b40-97fb-4f08-a8e0-4054e29d42e7",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "raw_image = Image.open(\"dogs.jpg\")\n",
    "raw_image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b9b59e0",
   "metadata": {},
   "source": [
    ">Note: The images referenced in this notebook have already been uploaded to this classroom's Jupyter directory for your convenience. For further details, please refer to the **Appendix** section at the end of the lessons."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87d2d941",
   "metadata": {},
   "source": [
    "* Resize the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4517328e-612e-4320-8d09-e15cbf37daa7",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "from utils import resize_image\n",
    "resized_image = resize_image(raw_image, input_size=1024)\n",
    "resized_image"
   ]
  },
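  {
   "cell_type": "markdown",
   "id": "5f1c2a90",
   "metadata": {},
   "source": [
    "The `resize_image` helper comes from `utils`, so its implementation is not shown in this notebook. As a rough sketch of what such a helper might do, the hypothetical `resize_to_width` below scales an image to a target width while keeping its aspect ratio. The function name and the choice to fix the width (rather than the longer side) are assumptions, not the course's actual code.\n",
    "\n",
    "```python\n",
    "from PIL import Image\n",
    "\n",
    "def resize_to_width(image, target_width=1024):\n",
    "    # Hypothetical helper: scale so the width equals target_width,\n",
    "    # keeping the aspect ratio unchanged.\n",
    "    width, height = image.size\n",
    "    scale = target_width / width\n",
    "    new_height = max(1, round(height * scale))\n",
    "    return image.resize((target_width, new_height))\n",
    "\n",
    "# Usage on a blank 2048x1536 test image\n",
    "demo = Image.new('RGB', (2048, 1536))\n",
    "print(resize_to_width(demo).size)  # (1024, 768)\n",
    "```"
   ]
  },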
  {
   "cell_type": "markdown",
   "id": "bba1103c",
   "metadata": {},
   "source": [
    "### Import and prepare the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0714f3-c394-49c8-8673-5f005a8ce49f",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d48635d4",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7496aabd",
   "metadata": {},
   "source": [
    "Info about [torch](https://pytorch.org/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5ecc9ea-b834-4ad9-adb2-7d0c853d8d73",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from ultralytics import YOLO\n",
    "model = YOLO('./FastSAM.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "893be50b",
   "metadata": {},
   "source": [
    "Info about ['FastSAM'](https://docs.ultralytics.com/models/fast-sam/)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2f43d87",
   "metadata": {},
   "source": [
    "### Use the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3373596e",
   "metadata": {},
   "source": [
    ">Note: ```utils``` is an additional file containing helper functions that have already been written for you to use in this classroom.\n",
    "For further details, please refer to the **Appendix** section located at the end of the lessons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6a83331-2775-4d0d-abb1-53648e1e6a64",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import show_points_on_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7278e50f-5b08-4f14-b62b-3b6c670096da",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Define the coordinates for the point in the image\n",
    "# [x_axis, y_axis]\n",
    "input_points = [ [350, 450] ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6e41f2-744a-4971-8610-0287a329341c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "input_labels = [1] # positive point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e724e32-4f35-4184-9dfc-4d3f8cafdc3c",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Function written in the utils file\n",
    "show_points_on_image(resized_image, input_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7532bc8d-ff31-4de3-9d71-98ed494933d5",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Run the model\n",
    "results = model(resized_image, device=device, retina_masks=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7761784b",
   "metadata": {},
   "source": [
    "* Filter the masks using the point defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d79f594-0347-4cf9-b390-8223b553ed17",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import format_results, point_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d88f329-6ae2-41ee-8128-6fe9b739b0d2",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# Convert the raw model output into the format expected by point_prompt\n",
    "results = format_results(results[0], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37eb6dc8-8153-4eee-9e2d-9f3d330c2e52",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Select the masks that match the point prompt\n",
    "masks, _ = point_prompt(results, input_points, input_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5dc12d9-6d4c-48b5-b59f-a096530199bc",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import show_masks_on_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d3f983-34c8-4382-9c00-5094693f26b3",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Visualize the generated masks\n",
    "show_masks_on_image(resized_image, [masks])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c7308c2",
   "metadata": {},
   "source": [
    "* Define a 'semantic mask': prompt with two points so that both regions are masked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8bad722-c4ce-400e-bc2a-989842bd6e5e",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Specify two points in the same image\n",
    "# [x_axis, y_axis]\n",
    "input_points = [ [350, 450], [620, 450] ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1136850-5661-4136-baec-58d00fbc4f74",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Specify both points as \"positive prompt\"\n",
    "input_labels = [1 , 1] # both positive points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0200106e-b203-4f1b-9ad0-d21a2efcc60d",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Visualize the points defined before\n",
    "show_points_on_image(resized_image, input_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e97e7bf-e286-490d-98b0-c4eec087000e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Run the model\n",
    "results = model(resized_image, device=device, retina_masks=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60fb80f-10c0-4ba9-8734-c06ae9409017",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "results = format_results(results[0], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c46b0b4c-49d0-4392-b479-2c8bf03b767c",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Select the masks that match the point prompts\n",
    "masks, _ = point_prompt(results, input_points, input_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2217634e-8f79-42ae-97ca-4b55021bf759",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Visualize the generated masks\n",
    "show_masks_on_image(resized_image, [masks])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1304fc0e",
   "metadata": {},
   "source": [
    ">Note: The results obtained from running this notebook may vary slightly from those shown by the instructor in the video."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ec6f9b8",
   "metadata": {},
   "source": [
    "* Identify subsections of the image by adding a **negative prompt**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "751c9d1d-27aa-4225-a78e-1b2f74b1eb46",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Define the coordinates for the positive and negative points\n",
    "# [x_axis, y_axis]\n",
    "input_points = [ [350, 450], [400, 300] ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1de18e03-8de4-404c-be78-a753e411435c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "input_labels = [1, 0] # positive prompt, negative prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33c1ca7a-22f1-4b75-8b1c-e232614d8139",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Visualize the points defined above\n",
    "show_points_on_image(resized_image, input_points, input_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ce6bd8c",
   "metadata": {},
   "source": [
    ">Note: In the image above, the red star marks the negative prompt and the green star marks the positive prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a6f408a-3b66-4b17-91ab-a03c5d77f57c",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Run the model\n",
    "results = model(resized_image, device=device, retina_masks=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4734d092-37ef-4570-8d32-7cdfc475b110",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "results = format_results(results[0], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf12e0f-142c-4c24-bd2f-6c5bbb976c21",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Select the masks that match the positive and negative prompts\n",
    "masks, _ = point_prompt(results, input_points, input_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c2b01b-dd4c-426e-9b0e-edd6330be670",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Visualize the generated masks\n",
    "show_masks_on_image(resized_image, [masks])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a60a4807",
   "metadata": {},
   "source": [
    ">Note: In the image above, only the jacket of the dog on the left was segmented: the mask follows the positive prompt and leaves out the region marked by the negative prompt."
   ]
  },
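  {
   "cell_type": "markdown",
   "id": "7b3e9d14",
   "metadata": {},
   "source": [
    "The `point_prompt` helper is defined in `utils`, so its logic is not visible here. One plausible selection rule, sketched below under that caveat, adds every candidate mask that contains a positive point, subtracts every candidate mask that contains a negative point, and keeps the pixels whose score stays positive. The function `select_masks_by_points` and the toy `dog` and `jacket` masks are hypothetical illustrations, not the course's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def select_masks_by_points(masks, points, labels):\n",
    "    # masks: (N, H, W) boolean array of candidate masks\n",
    "    # points: list of [x, y]; labels: 1 = positive, 0 = negative\n",
    "    score = np.zeros(masks.shape[1:], dtype=int)\n",
    "    for mask in masks:\n",
    "        m = mask.astype(int)\n",
    "        for (x, y), label in zip(points, labels):\n",
    "            if mask[y, x]:  # arrays are indexed [row, col] = [y, x]\n",
    "                score += m if label == 1 else -m\n",
    "    return score >= 1\n",
    "\n",
    "# Toy scene: a small 'jacket' mask nested inside a larger 'dog' mask\n",
    "dog = np.zeros((6, 6), dtype=bool)\n",
    "dog[1:5, 1:5] = True\n",
    "jacket = np.zeros((6, 6), dtype=bool)\n",
    "jacket[3:5, 1:5] = True\n",
    "toy_masks = np.stack([dog, jacket])\n",
    "\n",
    "whole_dog = select_masks_by_points(toy_masks, [[2, 3]], [1])\n",
    "only_jacket = select_masks_by_points(toy_masks, [[2, 3], [2, 1]], [1, 0])\n",
    "print(np.array_equal(whole_dog, dog), np.array_equal(only_jacket, jacket))  # True True\n",
    "```\n",
    "\n",
    "In this toy scene, the positive point alone returns the union of both masks (the whole dog). Adding a negative point that lies inside the dog but outside the jacket cancels the larger mask and leaves only the jacket."
   ]
  },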
  {
   "cell_type": "markdown",
   "id": "2ae4e7a2",
   "metadata": {},
   "source": [
    "### Prompting with bounding boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "401a0557-42b8-4d3d-99f7-96ba22860b0c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import box_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2de8a106-4035-4e8f-a0b2-3cc65ae4afed",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Set the bounding box coordinates\n",
    "# [xmin, ymin, xmax, ymax]\n",
    "input_boxes = [530, 180, 780, 600]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed7b3975",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import show_boxes_on_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f9b80ee-c1a8-440c-b120-0adc416e2144",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Visualize the bounding box defined with the coordinates above\n",
    "show_boxes_on_image(resized_image, [input_boxes])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "124071f3",
   "metadata": {},
   "source": [
    "* Now, isolate the mask that matches the bounding box from the model's full output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0324ec8-d2ae-46ea-9619-453471cca9db",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "results = model(resized_image, device=device, retina_masks=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62055a75-7c2a-4b99-b473-2788d6be964e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Extract the raw mask tensor from the first result\n",
    "masks = results[0].masks.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f52345d-3887-4da6-88bc-54303d4ef98c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db95c252-8e98-467c-a73c-9fcdd2dff0cc",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Convert to True/False boolean mask\n",
    "masks = masks > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c0589c-d983-48b3-a69b-55ebd69ade8b",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1dd714c-1f3b-4532-a1d9-3361f53ce829",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "masks, _ = box_prompt(masks, input_boxes)"
   ]
  },
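  {
   "cell_type": "markdown",
   "id": "c2a6f0e8",
   "metadata": {},
   "source": [
    "`box_prompt` is also defined in `utils`. A common way to match a box to a segmentation mask is intersection over union (IoU): measure how much each candidate mask overlaps the box and keep the best-scoring one. The sketch below assumes that approach; `select_mask_by_box` and the toy masks are hypothetical and may differ from what `box_prompt` actually does.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def select_mask_by_box(masks, box):\n",
    "    # masks: (N, H, W) boolean array; box: [xmin, ymin, xmax, ymax]\n",
    "    xmin, ymin, xmax, ymax = box\n",
    "    box_area = (xmax - xmin) * (ymax - ymin)\n",
    "    # Number of pixels of each mask that fall inside the box\n",
    "    inside = masks[:, ymin:ymax, xmin:xmax].sum(axis=(1, 2))\n",
    "    mask_area = masks.sum(axis=(1, 2))\n",
    "    iou = inside / (box_area + mask_area - inside)\n",
    "    best = int(np.argmax(iou))\n",
    "    return masks[best], best\n",
    "\n",
    "# Two toy masks: one on the left, one on the right of an 8x8 image\n",
    "left = np.zeros((8, 8), dtype=bool)\n",
    "left[2:6, 0:3] = True\n",
    "right = np.zeros((8, 8), dtype=bool)\n",
    "right[1:7, 4:8] = True\n",
    "toy_masks = np.stack([left, right])\n",
    "\n",
    "selected, index = select_mask_by_box(toy_masks, [4, 1, 8, 7])\n",
    "print(index)  # 1\n",
    "```\n",
    "\n",
    "Here the box coincides with the right-hand mask, so that mask has an IoU of 1 and is selected."
   ]
  },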
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c3a8bb7-c2f4-4e1e-a653-734bb97db5f2",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Visualize the masks\n",
    "show_masks_on_image(resized_image, [masks])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f2f8ae5-83d1-4695-a89b-f39fa62e4f3a",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Display the selected segmentation mask in its raw array form\n",
    "masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cb1919a-59d5-43e8-a7ef-ac6bd12bb5cf",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# To visualize, import matplotlib\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d022c91-c2aa-4b93-8676-3d4739293c5f",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Plot the binary mask as an image\n",
    "plt.imshow(masks, cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8c368ba",
   "metadata": {},
   "source": [
    "### Try it yourself!\n",
    "Try the image segmentation steps above with your own images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e00a69-72ff-45c2-a6da-1cb158c118f1",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Upload your own image or use one of the sample images provided, for example younes.png\n",
    "# (younes.png is already uploaded in this classroom)\n",
    "raw_image = Image.open('younes.png')\n",
    "raw_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53952bf6",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Resize image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c4e6885",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Define the coordinates for the point: [x_axis, y_axis]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57087862",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Define the positive or negative prompt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "446d53ff",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# show_points_on_image(resized_image, input_points)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb6f774f",
   "metadata": {},
   "source": [
    "### Additional Resources\n",
    "\n",
    "* For more on how to use [Comet](https://www.comet.com/site/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L2) for experiment tracking, check out this [Quickstart Guide](https://colab.research.google.com/drive/1jj9BgsFApkqnpPMLCHSDH-5MoL_bjvYq?usp=sharing) and the [Comet Docs](https://www.comet.com/docs/v2/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L2).\n",
    "* This course is based on two blog articles from Comet. Explore them for more on how to use newer versions of Stable Diffusion in this pipeline, additional tricks to improve your inpainting results, and a breakdown of the pipeline architecture:\n",
    "  * [SAM + Stable Diffusion for Text-to-Image Inpainting](https://www.comet.com/site/blog/sam-stable-diffusion-for-text-to-image-inpainting/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L2)\n",
    "  * [Image Inpainting for SDXL 1.0 Base Model + Refiner](https://www.comet.com/site/blog/image-inpainting-for-sdxl-1-0-base-refiner/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L2)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
